--- title: Prediction on new data keywords: fastai sidebar: home_sidebar summary: "Notebook to predict the segmentation masks on new data using pretrained or customized models. " description: "Notebook to predict the segmentation masks on new data using pretrained or customized models. " nb_path: "nbs/predict.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
#@markdown Please run this cell to get started.
%load_ext autoreload
%autoreload 2
# `files` and `drive` are only importable on Google Colab; elsewhere we
# silently continue and the Colab-specific cells below print a warning instead.
try:
    from google.colab import files, drive
except ImportError:
    pass
# Install deepflash2 on the fly if missing (e.g. on a fresh Colab runtime).
try:
    import deepflash2
except ImportError:
    !pip install -q deepflash2
import zipfile
import imageio
from fastai.vision.all import *
from deepflash2.all import *
{% endraw %}

Provide Data

Required data structure

  • One folder for images (different from training image folder!)

Example structure:

  • [folder] images_new
    • [file] 0001.tif
    • [file] 0002.tif

Working on Google Colab (recommended)

This section allows you to upload a zip folder or connect to your Google Drive.

Connect to Google Drive (recommended)

  • The folder in your drive must contain all images in one folder.
  • See here how to organize your files in Google Drive.
  • See this stackoverflow post for browsing files with the file browser
{% raw %}
# Mount Google Drive and point `path` at the data folder inside it.
try:
    drive.mount('/content/drive')
    path = "/content/drive/My Drive/data" #@param {type:"string"}
    path = Path(path)
    #@markdown Example: "/content/drive/My Drive/data"
except Exception:
    # Outside Colab `drive` is undefined (NameError); a bare `except:` would
    # also swallow KeyboardInterrupt/SystemExit, so catch Exception only.
    print("Warning: Connecting to Google Drive only works on Google Colab.")
Warning: Connecting to Google Drive only works on Google Colab.
{% endraw %}

Upload a zip file

  • The zip file must contain all images in one folder
  • See here how to zip files on Windows or Mac.
{% raw %}
# Upload one or more zip archives and extract each into `path`.
path = Path('data')
try:
    u_dict = files.upload()
    for key in u_dict:  # iterate the dict directly; .keys() is redundant
        unzip(path, key)
except Exception:
    # `files` is undefined outside Colab; catch Exception, not a bare except.
    print("Warning: File upload only works on Google Colab.")
Warning: File upload only works on Google Colab.
{% endraw %}

Local Installation

If you're working on your local machine or server, provide a path to the correct folder.

{% raw %}
# Path to the folder that contains the images to predict on (local install).
path = "my_data" #@param {type:"string"}
path = Path(path)
#@markdown Example: "new_images"
{% endraw %}

Try with sample data

If you don't have any data available yet, try our sample data

{% raw %}
# Download and unpack the cFOS sample dataset shipped with deepflash2.
import urllib.request  # explicit import: don't rely on star-imports to provide it
path = Path('sample_data_cFOS')
url = "https://github.com/matjesg/deepflash2/releases/download/model_library/wue1_cFOS_small.zip"
urllib.request.urlretrieve(url, 'sample_data_cFOS.zip')
unzip(path, 'sample_data_cFOS.zip')
{% endraw %}

Load data

{% raw %}
image_folder = "images" #@param {type:"string"}
# NOTE(review): this rebinds `files`, shadowing `google.colab.files` imported in
# the setup cell — the later `files.upload()` / `files.download()` cells will
# fail on Colab after this runs. Consider renaming this variable.
files = get_image_files(path/image_folder)
#@markdown Number of classes: e.g., 2 for binary segmentation (foreground and background class)
n_classes = 2 #@param {type:"integer"}
# Tile the images so arbitrarily large inputs fit through the network.
ds = TileDataset(files, n_classes=n_classes)
{% endraw %} {% raw %}
#@markdown Use the slider to control the number of displayed images
first_n = 3 #@param {type:"slider", min:1, max:100, step:1}
# Preview the first images; mask overlay disabled.
ds.show_data(max_n=first_n, figsize=(5,5), overlay=False)
{% endraw %}

Model Definition

Select model architecture

{% raw %}
model_arch = 'unet_deepflash2' #@param ["unet_deepflash2",  "unet_falk2019", "unet_ronnberger2015"]
# Infer the input channel count from the last axis of one sample image.
n_channels = ds.get_data(max_n=1)[0].shape[-1]
# Architecture only (pretrained=False); weights are loaded in the cells below.
model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=False, n_classes=ds.c, in_channels=n_channels, force_reload=True)
# Filled with checkpoint paths by one of the following weight-selection cells.
model_list = L()
Downloading: "https://github.com/matjesg/deepflash2/archive/master.zip" to /home/firstusr/.cache/torch/hub/master.zip
{% endraw %}

Select pretrained model weights

Using customized weights (recommended)

{% raw %}
#@markdown Models should be saved in the 'models' folder of your provided path.
models_folder = "models" #@param {type:"string"}
# Collect every .pth checkpoint in the folder; each becomes an ensemble member.
model_list = get_files(path/models_folder, extensions='.pth')
print('Found models', model_list)
Found models (#2) [Path('sample_data_cFOS/models/model1.pth'),Path('sample_data_cFOS/models/model2.pth')]
{% endraw %} {% raw %}
# Optionally upload additional .pth checkpoints (Colab only).
try:
    # Re-import under an alias: the global name `files` is shadowed by the
    # image file list created in the "Load data" cell, so `files.upload()`
    # would fail even on Colab.
    from google.colab import files as colab_files
    u_dict = colab_files.upload()
    model_list += [Path(u) for u in u_dict]
except Exception:
    print("Warning: File upload only works on Google Colab.")
print('Found models', model_list)
Warning: File upload only works on Google Colab.
Found models (#2) [Path('sample_data_cFOS/models/model1.pth'),Path('sample_data_cFOS/models/model2.pth')]
{% endraw %}

Using pretrained weights from deepflash2

{% raw %}
# NOTE(review): the default value is not among the listed form options — verify.
pretrained_weights = "wue1_cFOS" #@param ["cFOS", "Parv"]
model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=True, dataset=pretrained_weights, n_classes=ds.c, in_channels=n_channels)
# Resolve the checkpoint location portably via torch.hub instead of
# hard-coding a specific user's home directory.
model_list = L(Path(torch.hub.get_dir())/'checkpoints'/f'{pretrained_weights}.pth')
Using cache found in /home/firstusr/.cache/torch/hub/matjesg_deepflash2_master
{% endraw %}

Prediction

{% raw %}
# Run every model in `model_list` over the tiled dataset; results are keyed by
# (model name, file) so they can be ensembled afterwards.
res, res_mc = {}, {}
for m in progress_bar(model_list):
    print(f'Model {m.stem}')
    # No shuffling / no drop_last so predictions line up 1:1 with `files`.
    dls = DataLoaders.from_dsets(ds, batch_size=4 ,shuffle=False, drop_last=False)
    state_dict = torch.load(m)
    # strict=False tolerates key mismatches between checkpoint and architecture.
    model.load_state_dict(state_dict, strict=False)
    # Tuple expression is just a one-line way to move both to the GPU.
    if torch.cuda.is_available(): dls.cuda(), model.cuda()
    # loss_func=0 is a dummy — the Learner is used for inference only.
    learn = Learner(dls, model, loss_func=0)#.to_fp16()
    
    print(f'Predicting segmentation masks')
    smxs, segs, _ = learn.predict_tiles(dl=dls.train)   
    print(f'Predicting uncertainty maps')
    # MC dropout: n_times stochastic forward passes; `std` is the per-pixel spread.
    smxs_mc, segs_mc, std = learn.predict_tiles(dl=dls.train, mc_dropout=True, n_times=10)
    
    #TODO Save results not using RAM
    for i, file in enumerate(files):
        res[(m.stem, file)] = smxs[i], segs[i]
        res_mc[(m.stem, file)] = smxs_mc[i], segs_mc[i], std[i]
100.00% [1/1 00:15<00:00]
Model wue1_cFOS
Predicting segmentation masks
Predicting uncertainty maps
{% endraw %}

Ensembling

Here you can validate your results. If you choose to only train one model (n_models = 1), ensemble and model results will be the same.

{% raw %}
# Output folders for ensemble predictions and uncertainty maps.
pred_dir = 'preds' #@param {type:"string"}
pred_path = path/pred_dir/'ensemble'
pred_path.mkdir(parents=True, exist_ok=True)
uncertainty_dir = 'uncertainties' #@param {type:"string"}
uncertainty_path = path/uncertainty_dir/'ensemble'
uncertainty_path.mkdir(parents=True, exist_ok=True)

#@markdown Define `filetype` to save the predictions and uncertainties. All common filetypes are supported.
filetype = 'png' #@param {type:"string"}
{% endraw %} {% raw %}
# Ensemble the per-model results, plot them, and save masks + uncertainties.
res_list = []
for file in files:
    img = ds.get_data(file)[0]
    # (removed) fetching masks with `mask=True` — new prediction data has no
    # ground-truth masks and the value was never used.
    pred = ensemble_results(res, file)
    pred_std = ensemble_results(res_mc, file, std=True)
    df_tmp = pd.Series({'file' : file.name, 'entropy': mean_entropy(pred_std)})
    plot_results(img, pred, pred_std, df=df_tmp)
    res_list.append(df_tmp)
    # Multi-class masks keep raw label values; binary masks are scaled to 0/255.
    imageio.imsave(pred_path/f'{file.name}_pred.{filetype}', pred.astype(np.uint8) if np.max(pred)>1 else pred.astype(np.uint8)*255)
    # Scale BEFORE casting: pred_std is a float map in [0, 1], so casting to
    # uint8 first truncates almost every pixel to 0.
    imageio.imsave(uncertainty_path/f'{file.name}_uncertainty.{filetype}', (pred_std*255).astype(np.uint8))
df_res = pd.DataFrame(res_list)
df_res.to_csv(path/'ensemble_results.csv', index=False)
{% endraw %} {% raw %}
# Inspect and save the results of a single ensemble member.
model_number = 1 #@param {type:"slider", min:1, max:5, step:1}
model_name = model_list[model_number-1].stem
pred_path = path/pred_dir/model_name
pred_path.mkdir(parents=True, exist_ok=True)
uncertainty_path = path/uncertainty_dir/model_name
uncertainty_path.mkdir(parents=True, exist_ok=True)
res_list = []
for file in files:
    img = ds.get_data(file)[0]
    pred = res[(model_name,file)][1]            # segmentation from this model
    pred_std = res_mc[(model_name,file)][2][...,0]  # MC-dropout std map
    df_tmp = pd.Series({'file' : file.name, 'entropy': mean_entropy(pred_std)})
    plot_results(img, pred, pred_std, df=df_tmp)
    res_list.append(df_tmp)
    # Multi-class masks keep raw label values; binary masks are scaled to 0/255.
    imageio.imsave(pred_path/f'{file.name}_pred.{filetype}', pred.astype(np.uint8) if np.max(pred)>1 else pred.astype(np.uint8)*255)
    # Scale BEFORE casting: pred_std is a float map in [0, 1], so casting to
    # uint8 first truncates almost every pixel to 0.
    imageio.imsave(uncertainty_path/f'{file.name}_uncertainty.{filetype}', (pred_std*255).astype(np.uint8))
df_res = pd.DataFrame(res_list)
df_res.to_csv(path/f'{model_name}_results.csv', index=False)
{% endraw %}

Download Section

To download validation predictions and uncertainties, you first need to execute Section Validate models and ensembles.

Note: If you're connected to Google Drive, the models are automatically saved to your drive.

{% raw %}
# Zip all prediction images and offer them for download (Colab only).
out_name = 'predictions.zip'
with zipfile.ZipFile(path/out_name, 'w') as zf:
    for f in get_image_files(path/pred_dir):
        zf.write(f)
try:
    # Alias import: the global `files` is shadowed by the image list from the
    # "Load data" cell, so `files.download(...)` would fail even on Colab.
    from google.colab import files as colab_files
    colab_files.download(path/out_name)
except Exception:
    print("Warning: File download only works on Google Colab.")
Warning: File download only works on Google Colab.
{% endraw %} {% raw %}
# Zip all uncertainty maps and offer them for download (Colab only).
out_name = 'uncertainties.zip'
with zipfile.ZipFile(path/out_name, 'w') as zf:
    # Fixed copy-paste bug: this previously zipped `path/pred_dir`, so the
    # "uncertainties" archive actually contained the predictions.
    for f in get_image_files(path/uncertainty_dir):
        zf.write(f)
try:
    # Alias import: the global `files` is shadowed by the image list from the
    # "Load data" cell, so `files.download(...)` would fail even on Colab.
    from google.colab import files as colab_files
    colab_files.download(path/out_name)
except Exception:
    print("Warning: File download only works on Google Colab.")
Warning: File download only works on Google Colab.
{% endraw %} {% raw %}
# Zip all CSV result tables and offer them for download (Colab only).
out_name = 'results.zip'
with zipfile.ZipFile(path/out_name, 'w') as zf:
    for f in get_files(path, extensions='.csv'):
        zf.write(f)
try:
    # Alias import: the global `files` is shadowed by the image list from the
    # "Load data" cell, so `files.download(...)` would fail even on Colab.
    from google.colab import files as colab_files
    colab_files.download(path/out_name)
except Exception:
    print("Warning: File download only works on Google Colab.")
Warning: File download only works on Google Colab.
{% endraw %}